{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "191fbf06-18c8-49b8-b1b1-2edb7fca64d7",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Visualization and explanation of the synthetic datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2207ec9-bca0-4dd6-b68c-7cb77405140a",
   "metadata": {},
   "source": [
    "In this work, we use two synthetic datasets, ```FourBars``` and ```ColorBar```, to test our algorithms. In this notebook, the generators are presented and the datasets are visualized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c48b7863",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Remember where the kernel started so that re-running this cell does not\n",
    "# climb one directory higher each time (os.chdir(\"..\") is not idempotent).\n",
    "if \"_NOTEBOOK_DIR\" not in globals():\n",
    "    _NOTEBOOK_DIR = os.getcwd()\n",
    "_PROJECT_ROOT = os.path.dirname(_NOTEBOOK_DIR)\n",
    "\n",
    "# Register an absolute path: a relative \"..\" entry in sys.path would point\n",
    "# somewhere else once the working directory has been changed.\n",
    "if _PROJECT_ROOT not in sys.path:\n",
    "    sys.path.append(_PROJECT_ROOT)\n",
    "os.chdir(_PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c0129cae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OTHER IMPORTS\n",
    "import torch\n",
    "\n",
    "# SETUP for the figures.\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "# The seaborn style sheets were renamed to 'seaborn-v0_8-*' in matplotlib 3.6 and the\n",
    "# old names are gone in recent releases, so use whichever name this install provides.\n",
    "paper_style = 'seaborn-v0_8-paper' if 'seaborn-v0_8-paper' in plt.style.available else 'seaborn-paper'\n",
    "plt.style.use(paper_style)\n",
    "\n",
    "#plt.rc('font', family='Avenir', serif='Computer Modern')\n",
    "plt.rc('xtick', labelsize=8)\n",
    "plt.rc('ytick', labelsize=8)\n",
    "plt.rc('axes', labelsize=8)\n",
    "plt.rc('axes', titlesize=8)\n",
    "plt.rc('legend',fontsize=8) # using a size in points\n",
    "plt.rcParams['axes.linewidth'] = 1.15\n",
    "#plt.rcParams[\"figure.figsize\"] = (6.,4.1631189606246317)\n",
    "\n",
    "# Draw all four spines and keep the bottom/left ticks and tick labels visible.\n",
    "rc = {\"axes.spines.left\" : True,\n",
    "      \"axes.spines.right\" : True,\n",
    "      \"axes.spines.bottom\" : True,\n",
    "      \"axes.spines.top\" : True,\n",
    "      \"xtick.bottom\" : True,\n",
    "      \"xtick.labelbottom\" : True,\n",
    "      \"ytick.labelleft\" : True,\n",
    "      \"ytick.left\" : True\n",
    "     }\n",
    "\n",
    "plt.rcParams.update({\n",
    "    # Render all text with LaTeX (configured once, here).\n",
    "    \"text.usetex\": True,\n",
    "    \"pgf.texsystem\": \"pdflatex\",\n",
    "    \"pgf.preamble\": r\"\\usepackage{bm}\",\n",
    "    \"font.family\": \"serif\",\n",
    "    # Use Times as the serif font.\n",
    "    \"font.serif\": [\"Times\"],\n",
    "    # Use specific cursive fonts.\n",
    "    \"font.cursive\": [\"Comic Neue\", \"Comic Sans MS\"],\n",
    "})\n",
    "plt.rcParams.update(rc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "a71c7e8b-4337-486b-80bb-ca3f6aa46f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.four_bars import FourBars, ColorBar "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4dcf56af",
   "metadata": {},
   "source": [
    "## Initialize the datasets and plot random samples and traversals"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6ce7c9b-cd76-45bf-9da8-ed496d6c6b48",
   "metadata": {},
   "source": [
    "We implemented two datasets, ```FourBars``` and ```ColorBar``` (see Figures 3a and 3b in the paper). ```FourBars``` has 4 factors of variation, whereas ```ColorBar``` has only 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "id": "af7a4732-75f2-461e-89d4-5724eb78a1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Flip this flag to choose which synthetic dataset is used below.\n",
    "use_colorbar = True\n",
    "n_intervals = 11\n",
    "syn_dataset = ColorBar(n_intervals, nonlin_colors=False) if use_colorbar else FourBars(n_intervals)\n",
    "num_factors = 3 if use_colorbar else 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 142,
   "id": "e044bf23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the sample_factors function, to draw n_gen valid factor combinations.\n",
    "n_gen = 8\n",
    "facts = syn_dataset.sample_factors(n_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 143,
   "id": "68ea3bbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 9,  8,  4],\n",
       "       [ 0,  1,  8],\n",
       "       [ 5, 10,  1],\n",
       "       [ 1,  8,  6],\n",
       "       [ 7,  0,  0],\n",
       "       [ 3,  5,  8],\n",
       "       [ 6,  3,  3],\n",
       "       [ 1,  3,  3]])"
      ]
     },
     "execution_count": 143,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "facts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaa00ad2-9961-44f9-8580-d6a898b961a3",
   "metadata": {},
   "source": [
    "```sample_factors``` draws values for the three (```ColorBar```) or four (```FourBars```) factors, each in the range ```[0, n_intervals-1]```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 144,
   "id": "674666dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Then use sample_observations_from_factors to generate the actual images\n",
    "batch = syn_dataset.sample_observations_from_factors(facts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 145,
   "id": "8fbf5232",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 146,
   "id": "44cea02c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAABAsAAACOCAYAAABXCIHPAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADedJREFUeJzt3U2MXedZB/DnjS3bqcdk8JiFV1WdNIVCF7GmalRYOqIVdEFkqIBdVU1Agg0SiiJQvwSiZsGGla1KbPgQaTcUFoVaYlciJYpAgChu5C5ANRKZ0SRxG9eN9XbhW+Q3jO89d+ac+3Ge308aKZHOOc9z7//x9fGje++UWmsAAAAA/Mgjy24AAAAAWC2WBQAAAEDDsgAAAABoWBYAAAAADcsCAAAAoGFZAAAAADQsCwAAAICGZQEAAADQON7nxUopxyLifRGxHxG1z2vTuxIRmxHx7Vrrvd4uagbWhfzpfQbkv1bkn5v8c3MPkJv86TwDvS4L4v6AfKvnazKs90fEaz1ezwysF/nT5wzIf/3IPzf55+YeIDf5M3MG+l4W7EdE3LhxI86ePdvzpenT3t5ePPnkkxGTzHpkBtbA0Pl/5lc/E6dPne58Uq0W0EdVSpnr+O/e+W584a++ENHvDPjzvyYGeg1Y+fx/5tyfLqzWv73+2wurNa+s+XOfe8Dc5M88M9D3sqBGRJw9eza2trZ6vjQD6ftfaWZgvQyS/+lTp2Pj1Eb3kywLjmzeZcED+nzy/flfP6nyPxbdl5hHtarPwbukyp//xz1gbvJn5gz4gkMAAACgYVkAAAAANKZ+DKGUshkRlyLibES8WGvt+7MtrDD5YwZyk39u8s9N/piB3ORPxOx3FrwQEdcnPzvDt8OKkT9mIDf55yb/3OSPGchN/sz8gsOLky3Sfinlww87qJSyE/eH6FifzbF0nfKPMAMj5jUgN/nnJv/c5I8ZyE3+9POdBbXWa7XW7bj/VhUSMgO5yT83+ecm/9zkjxnITf7jNuudBTcnn1c5GxEvL6AfVov8MQO5yT83+ecmf8xAbvJn5rLgSkT8yuS/rw3cC6tH/piB3OSfm/xzkz9mIDf5M31ZUGu9GYYjLfljBnKTf27yz03+mIHc5E9ET99ZAAAAAIzHrI8hAMzt+p9djxNxovPxNeqA3eRQosx1/N24O1AnADC8b3zt3+c+5/t33hmgk4OdPDX/P7M++rGfHqCTcbr9tb+f+5x6584AnRysnDo19zkbH/v5ATo5Gu8sAAAAABqWBQAAAEDDsgAAAABoWBYAAAAADcsCAAAAoGFZAAAAADQsCwAAAICGZQEAAADQsCwAAAAAGpYFAAAAQMOyAAAAAGhYFgAAAAANywIAAACgcXyIi/7d374VZ86c6Hz8H/7B60O0kcrv/f65uY5/6623BuoEIt77offGo8cf7X5CHa6XNMp8h7/9ztsR/zpMK7Cq/rv+7rJbAHryuU/9+dzn/O+t/QE6OdhPnN+c+5x/+M4fDdDJOP3Ppz499znv3Lo1QCcHO37+/NznPPGd/xqgk6PxzgIAAACgYVkAAAAANCwLAAAAgIZlAQAAANCY+gWHpZSLEfHJiLhQa/3lxbTEqpA/ZiA3+ecm/9zkjxnITf5EdHhnQa31+Yh4eTIwByql7JRSXomI6302x/J1yT/CDIyZ14Dc5J+b/HOTP2YgN/kzdVlQa331gf+9OeW4a7XW7Yi41FdjLF/X/CfHmoER8hqQm/xzk39u8scM5CZ/Irp/Z8F+rXVxv5iUVSN/zEBu8s9N/rnJHzOQm/wTm7ksKKVcrrVeW0QzrB75YwZyk39u8s9N/piB3OTP1GVBKWUnIj5ZSrlaSrm8oJ5YEfLHDOQm/9zkn5v8MQO5yZ+IGb8NYbJJsk1KSv6Ygdzkn5v8c5M/ZiA3+RPR/TsLAAAAgCSmvrPgsPb27sXdu+90Pv5f/vnOEG2k
srvb/fmOiLh9+95AnUDEF//xi7G1tbXsNphid3c3vnTuS4Nc+3v/9FI8+thjnY9/8y/+cpA+MvmxX/+1uY7/3htvDNQJwGI8unFi7nNOb5waoJODHaY/uisbp+c+55GNjQE6Odhh+ltF3lkAAAAANCwLAAAAgIZlAQAAANCwLAAAAAAalgUAAABAw7IAAAAAaFgWAAAAAA3LAgAAAKBhWQAAAAA0LAsAAACAhmUBAAAA0LAsAAAAABrHh7joBz5wIh577GTn43/jN398iDZS+eAHuz/fERFvvHFioE4W7Nat+c956aX++3iYp5+e/5zz5/vvAxboBzduxN2Njc7H71+9NmA3OZy6+NRcx//g9u2BOoH57P7Ha4c6b++bhztvlZ39yScOdd7WTx3uvHX3Nzc+v+wWWKLHb3xz2S2k4J0FAAAAQMOyAAAAAGhYFgAAAAANywIAAACg0WlZUEr58tCNsLrkjxnITf65yT83+WMGcpN/bjOXBaWUS4tohNUkf8xAbvLPTf65yR8zkJv8mfqrE0spmxGxN/mZdtxOROxExLH+WmPZuuY/OdYMjJDXgNzkn5v8c5M/ZiA3+RMx+50F27XWV2ddpNZ6rda6HRG2T+PSKf8IMzBiXgNyk39u8s9N/piB3OTPzGXBxVLKlYjYLqVcXkRDrBT5YwZyk39u8s9N/piB3OTP9I8h1Fr/OCKilHK11vqVxbTEqpA/ZiA3+ecm/9zkjxnITf5EdPxtCLXW54ZuhNUlf8xAbvLPTf65yR8zkJv8c+u0LAAAAADysCwAAAAAGlO/s+CwPvL0e2Jr6z2dj//Zn+t+LP3Y3X172S3045VX5j/n2Wf77+NhvvrV+c/5xCf67wMW6JGtrTh25kzn408+9dSA3eRw7Ny5uY5/5OTJgTqB+fznXx/i78mI+Mbn/6TnTpbvo5/9ncOd97nDnQcwi3cWAAAAAA3LAgAAAKBhWQAAAAA0LAsAAACAhmUBAAAA0LAsAAAAABqWBQAAAEDDsgAAAABoWBYAAAAADcsCAAAAoGFZAAAAADQsCwAAAIDG8WU3AEdy/BAjfOZM/308zGH6gzV35hd/Ic5sbXU//tlfGrAbDnJ3d3fZLUBERBw7eeJQ5504s9FzJ8t32OcCYCjeWQAAAAA0LAsAAACAhmUBAAAA0LAsAAAAABqdlgWllAullM2hm2E1yR8zkJv8c5N/bvLHDOQm/9xmLgtKKVciYrPWur+Aflgx8scM5Cb/3OSfm/wxA7nJn6m/162UshMRMzdJk+N2IuJYT32xArrm/8CxZmBkvAbkJv/c5J+b/DEDucmfiNnvLHgmIq5GxPZkEA5Ua71Wa92OiEt9NsfSdco/wgyMmNeA3OSfm/xzkz9mIDf5M3NZsBcRNyPixYh4fPh2WDHyxwzkJv/c5J+b/DEDucmfmcuCKxHxQkRsx/3NErnIHzOQm/xzk39u8scM5CZ/pn9nQa31ZkQ8v6BeWDHyxwzkJv/c5J+b/DEDucmfiI6/OhEAAADIw7IAAAAAaEz9GAKsvI9/fP5z3nyz/z4AgLl95IXfWuh5AHTnnQUAAABAw7IAAAAAaFgWAAAAAA3LAgAAAKBhWQAAAAA0LAsAAACAhmUBAAAA0LAsAAAAABqWBQAAAEDDsgAAAABoWBYAAAAAjeM9X69EROzt7fV8Wfr2QEal50ubgTUgfwaaAfmvCfnnJv/c3APkJn/mmYFSa+2tcCnliYj4Vm8XZBHeX2t9ra+LmYG1I396mwH5ryX55yb/3NwD5CZ/Zs5A38uCYxHxvojYj4iDLnw9Ii71VnC6RdVa18dUImIzIr5da73X0zVnzcC6PlerUkv+3ak1W+8zsEJ/B6xrJousNeb8F1lrXR+T/NerTt+13AOsX611zz9iPZ+rsdbqPAO9fgxhUuyh24lSyr1a626fNZdda80f0+s9Xisips/Amj9XS68l/+7U6qzXGViVvwPWPJNF1hpl/ousteaPSf5r
UmegWu4B1qjWuucfsdbP1VhrdZoBX3AIAAAANBa9LLg2wlpjfExDGetzNcZZG8IYMxlzrSGM8c/KWGsNYYzP1Rgf01DG+FyN8TENZazP1RhnbShjfK7GWuv/9PqdBQAAAMD68zEEAAAAoGFZAAAAADQsCwAAAICGZQEAAADQsCwAAAAAGgtZFpRSNkspl0spO6WUzYFrXSylXCmlfHnIOg/UW0idSa0LQz9/Qxhz/pOai5o1+c+uNdr8J7XMwPQ68l9BY34NkP9sY85/UtM9wBTy763OWuYf4R6gx1pLmYFFvbPghYi4PvnZGbpYrfX5iHi5lHJxyDqllEtDXv9dta5ExGatdX9RNXs0yvwjFjcD8u9ujPlPapmBDuS/kkb5GiD/zkaZf4R7gI7kf/Q665x/hHuAPmotbQaOL6jOxcmD2y+lfHjIQrXWVx/435tD1ZlsdvYmP4MqpexExFpuEydGl3/E4mZA/t2NMf9JLTPQgfxX1uheA+Q/l9HlH+EeYA7yP1qddc8/wj3AUWstdQbG/J0F+wNvX7bfNZRDeiYirkbE9mRgmG3o/CMWNwPyn9+Y8o8wA/OSP+4BchvTa4D85yf/3MaUf8SSZ2BRy4Kbk8+rXIiIl4cuVkq5XGu9NnCZi5O3hGyXUi4PXGsv7m/IXoyIxweuNYQx5h+xuBmQ/xxGmH+EGehM/itpjK8B8u9ujPlHuAfoSv5Hs+75R7gHOKqlzkCptQ5f5P5w/OhzHS8Oue2ZbFyeiftP7NdrrV8Zqtak3tVa63MD17gQEc9FxNcj4matddC31vRtzPlPag46A/Kfq9bo8p/UMAPd6sh/BY35NUD+s405/0lN9wBTyP/I11/r/CPcA/RQY6kzsJBlAQAAALA+xvydBQAAAMAhWBYAAAAADcsCAAAAoGFZAAAAADQsCwAAAICGZQEAAADQsCwAAAAAGj8El1WAobbKiQwAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 1296x648 with 8 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# squeeze=False keeps the axes in a 2-D array, so the loop also works for n_gen == 1\n",
    "# (plt.subplots would otherwise return a single Axes object that cannot be iterated).\n",
    "f, ax_grid = plt.subplots(1, n_gen, figsize=(18, 9), squeeze=False)\n",
    "ax_list = ax_grid[0]\n",
    "# One panel per generated sample.\n",
    "for idx, a in enumerate(ax_list):\n",
    "    a.imshow(batch[idx])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71a047dd",
   "metadata": {},
   "source": [
    "## Generate ground truth traversals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 147,
   "id": "6d623120",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "n_intervals = 7\n",
    "# Every row starts with all factors at the value 5; inside the i-th block of\n",
    "# n_intervals rows, factor i is swept linearly from 0 to 10.\n",
    "facts = 5*torch.ones(num_factors*n_intervals, num_factors)\n",
    "sweep = torch.arange(n_intervals, dtype=torch.float)*(10./(n_intervals-1))\n",
    "for i in range(num_factors):\n",
    "    rows = slice(i*n_intervals, (i+1)*n_intervals)\n",
    "    facts[rows, i] = sweep\n",
    "\n",
    "batch = syn_dataset.sample_observations_from_factors(facts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "id": "8211d263",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import numpy as np\n",
    "def plot_traversal(batch, n_row, savepath=None, plotsize=(8,6)):\n",
    "    \"\"\"Show a batch of traversal images as a grid with one labelled row per factor.\n",
    "\n",
    "    Args:\n",
    "        batch: 4-D torch tensor of images; the last axis is moved to position 1\n",
    "            before calling make_grid, i.e. a channels-last layout is assumed.\n",
    "        n_row: number of images per grid row (forwarded to make_grid as nrow).\n",
    "        savepath: if not None, the figure is also saved there with the pgf backend.\n",
    "        plotsize: figure size in inches as (width, height).\n",
    "    \"\"\"\n",
    "    # Move the last axis to position 1 for make_grid ...\n",
    "    batch = batch.permute(0, 3, 1, 2)\n",
    "    img = torchvision.utils.make_grid(batch, nrow=n_row, padding=1)\n",
    "    # ... and move it back to the end for imshow.\n",
    "    img = img.permute(1, 2, 0)\n",
    "    plt.imshow(img)\n",
    "    height, width = img.shape[0], img.shape[1]\n",
    "    num_img_width = n_row\n",
    "    # Number of grid rows (ceil division). Deriving it from the batch instead of\n",
    "    # assuming 3 keeps labels and tick spacing correct for FourBars (4 factors).\n",
    "    num_img_height = (batch.shape[0] + n_row - 1) // n_row\n",
    "    limits = [-1, 1]\n",
    "    lines = [rf\"$z_{{{k + 1}}}$\" for k in range(num_img_height)]\n",
    "    # Place one y tick at the vertical centre of every grid row.\n",
    "    plt.yticks((np.arange(0, num_img_height)+0.5)*(height/num_img_height), lines)\n",
    "    # x ticks at the centre of the first, middle and last grid column.\n",
    "    plt.xticks(np.array([0.5,num_img_width/2, num_img_width-0.5])*(width/num_img_width), [str(limits[0]), \"0\", str(limits[1])])\n",
    "    plt.gcf().set_size_inches(*plotsize)\n",
    "    plt.tight_layout()\n",
    "    if savepath is not None:\n",
    "        plt.savefig(savepath, backend=\"pgf\", dpi=800)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "id": "f3902376",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAEOCAYAAACEvm3bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAACuBJREFUeJzt3bFuXFUawPHvI3kAK+EFYmlfAGV7ipTQRbuS3Wc76EArkGwjxCqUlO6NtRLl0qWAele8wCJnHyCgeQE4W9isWSTm3MncO/fO59+vinKvZo6OzpA/1/bnbK0FAEAlb8y9AACAsQkcAKAcgQMAlCNwAIByBA4AUI7AAQDKETgAQDkCBwAo5/4UL5qZ9yLiUUSsIsIkQQBgDBkRBxHxsrX207obJwmcuI6bf0/02gDA3faHiPh+3Q1TfYlqNdHrAgB0O2OqwPFlKQBgKt3O8E3GAEA5AgcAKEfgAADlCBwAoByBAwCUM9UcnK7W/KDVLzKze4/9utXbL3t1y14N53O4GWdrOGdrM0P2awhPcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoByBAwCUI3AAgHIEDgBQjsABAMoROABAOQIHAChH4AAA5QgcAKAcgQMAlHN/7gWs1dp213clc7vrI2hF9iqXsFfXN02+jq4Be7GI/VrAVkVERPdjuIC9iv04WrvYq4j+fu3DXl3fM//ZWsJeRSznbEV4ggMAFCRwAIByBA4AUI7AAQDKETgAQDkCBwAoR+AAAOUIHACgnGUP+js72+76rpycrL9+ejr5Ev5z9tna6y/PPp18DUM8Ovl4/fXTjyZfw7dnX3Tv+WbAPVN7++S9/j2n70++jsuzy7XXL84uJl/DEMcnx2uvH50eTb6GL86+HXDPN5Ovo+e9k7fXXn//dP31sXx29nLt9U8713fh45NH3Xs+Oj2cfB3+OdycJzgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcgQOAFBOdw5OZh5GxNOIWLXWzqdfEgDAdoY8wfkxIs4j4kVERGY+n3RFAABb6gZOa20VEYettavevZn5LDP/FTcxBAAwh27g3DyxWWXmQe/e1tp5a+1xRDwZY3EAAK9jyBOcD+P6e3B+iZZu6AAAzGnQL9tsrX3+qz//ZbrlAABsz4+JAwDlCBwAoJxsrY3/opkPI+LVunumeN99lZnde+zXrd5+2atb9mo4n8PNOFvDOVubGbJfEfFma+2HdTd4ggMAlCNwAIByBA4AUI7AAQDKETgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoByBAwCUI3AAgHIEDgBQjsABAMoROABAOQIHAChH4AAA5QgcAKAcgQMAlHN/rjfOzLneei/Zr+Hs1XD2ajP2azh7tRn7NT5PcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoJzZ5uC01uZ668UZMv/Aft3q7Ze9umWvhvM53IyzNZyztZmxZgJ5ggMAlCNwAIByBA4AUI7AAQDKETgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoJz7cy9gSu3ntvVr5Bs5wkrm19qAvdh+uyI625W5H/s5xtnZ1lLOXvfszL9V1xZw9pZwbsYwxtnb2X9zlmCEo1Xl7CyJJzgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcgQOAFBO6Tk479x7Z+vX+Lp9PcJK5nd59mX3nosB9/QcnxytvX50erz1e+zCu52zs4ORQfGPhZy9y7PLtdcvzi52tJL1jk/Wn62j0/Vncwy9cxOxjNEvuzh7l5+sPzcRERenyzg72zoe4b9r7957d+31toiTs70cY2jQQJ7gAADlCBwAoByBAwCUI3AAgHIEDgBQjsABAMoR
OABAOQIHAChH4AAA5QgcAKAcgQMAlNP9XVSZeRgRTyNi1Vo7n35JAADbGfLLNn+MiPOIeJCZb0XEnyPi76217357Y2Y+i4hnEXFv1FUCAGyg+yWq1toqIg5ba1c3UfO3iHj8O/eet9YeR8STcZcJADBcN3Ay83lErDLz4OavHvtSFQCwZN0vUbXWPszMDyLiKjMfRMRBZh601r6afnkAAJvL1tr4L5r5MCJerbtnivfdV5nZvcd+3ertl726Za+G8zncjLM1nLO1mSH7FRFvttZ+WHeDHxMHAMoROABAOQIHAChH4AAA5QgcAKAcgQMAlCNwAIByBA4AUI7AAQDKETgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoByBAwCUI3AAgHIEDgBQzv253jgz53rrvWS/hrNXw9mrzdiv4ezVZuzX+DzBAQDKETgAQDkCBwAoR+AAAOUIHACgHIEDAJQjcACAcmabg9Na697z5emXa69fnF2MtZytHJ8cr71+dHq09vqQ+Qe9/Rqyn/tgyF707qmyF2OwV8ON8Tm8S5yt4ZytzYw1E8gTHACgHIEDAJQjcACAcgQOAFCOwAEAyhE4AEA5AgcAKEfgAADlzDboj3Fdnl2uvV5lKOIYKg3UGmsgFkA1nuAAAOUIHACgHIEDAJQjcACAcgQOAFCOwAEAyhE4AEA5i56D05uJsouZKdTTmxkUsYy5Qb2ZQRELmRu0lLFCnZFAZgbB3eIJDgBQjsABAMoROABAOQIHAChH4AAA5QgcAKAcgQMAlCNwAIByFj3oj+EMRWQqvcGISxiKGNEfjLiIoYgD75lab+ihoYhU4AkOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoByBAwCUYw4Od86QeSjmBvE6ejODIpYxN2gJM4Mi+jOB2s8LmBn0Rn8mkLlBy+QJDgBQjsABAMoROABAOQIHAChH4AAA5XR/iiozDyPiaUSsWmvn0y8JAGA7Q35M/MeIOI+IB7+KnRette9+e2NmPouIZxFxb9RVAgBsoBs4rbVVZr71S9Dc/Lz/6nfuPY+I88x8GBGvxlwoMI/eTBQzg3hdvblB+zAzKGI3n4HeTKDeTKFd6c0EGjJXaCzd78HJzOcRscrMg5u/+ioinky6KgCALQx5gvNhZn4QEVeZuYqIq4h4MfnKAABe06Bf1dBa+3zqhQAAjMWPiQMA5QgcAKAcgQMAlCNwAIByBA4AUM6gn6ICoG/IwDeDEXkdl58sfyhiRH8w4i7Pvyc4AEA5AgcAKEfgAADlCBwAoByBAwCUI3AAgHIEDgBQjjk4AMyiNxPFzCC24QkOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoByBAwCUI3AAgHKytTb+i2Y+jIhXo78wAEDEm621H9bd4AkOAFCOwAEAyhE4AEA5AgcAKEfgAADlCBwAoJypAicnel0AgG5nTBU4BxO9LgBAtzOmGvR3LyIeRcQqIsZ/g/G9iIgncy8C+B+fSVieJXwuM67j5mVr7ad1N96f4t1v3vT7KV57Cpn5U28iIrA7PpOwPAv6XA76TQm+yRgAKEfgXDufewHA//GZhOXZq8/lJN+DAwAwJ09wAIByBA4AUI7AAQDKETgAQDkC50ZmfpCZT+deB9xlmXmQmU8z81lmmogOC7GP/0YKnFtXcy8AiL/G9bTUFxHxbOa1ALf27t/ISSYZA7ymt1prq4hYZeYf514MsL88wQEAyrmzT3BuvpZ4GBHnN//HCMzv6uZ7bx5ExD/nXgywv+5s4LTWvvrNX/0xIpbwS8TgLnseEX+6+fNejYWH4vbu30i/qgEAKMf34AAA5QgcAKAcgQMAlCNwAIByBA4AUI7AAQDKETgAQDkCBwAo579uoCQPs9yUNgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_traversal(torch.from_numpy(batch), n_row=n_intervals)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43bfab7d",
   "metadata": {},
   "source": [
    "## Test gradient computation of generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 150,
   "id": "46791ca4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Choose one factor configuration to inspect (a single row with num_factors entries)\n",
    "single_sample = [2.6759, 0.787, 2.5879] if use_colorbar else [8.4, 9.9, 0.1, 8.4]\n",
    "facts = torch.tensor(single_sample).reshape(1, -1)\n",
    "batch = syn_dataset.sample_observations_from_factors(facts, ret_torch=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 151,
   "id": "ebc259d9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f6ddc1ad978>"
      ]
     },
     "execution_count": 151,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ4AAAETCAYAAADH+ejgAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAACMpJREFUeJzt3b9uW+cdx+HvWxfdigqSh65R/nToJsjo0jG5A7dZOynoXMBGpqJToQC9AOkGitpBbyC+AzvZmybKXlgEt07B20HHgZJYPPw15CEZPg8gwIEJvq8V6sNzjg/9a733AFT8ZNMbAHaPcABlwgGUCQdQJhxAmXAAZcIBlAkHUCYcQNlPV/lkrbV7Sd5IMk/illTYHS3JQZKveu9fjz14peHITTT+veLnBKbzdpIvxh606nDMk+Tzzz/P4eHhip8aWJfZbJZ33nknGX6Gx6w6HD1JDg8Pc3R0tOKnBiaw1CUGF0eBMuEAyoQDKBMOoEw4gDLhAMqEAygTDqBMOICyhXeOttYOkryb5DDJk977UrejAj9uY0ccHyZ5NnydrX87wC4Y+6zKyXCUMW+tPbjrQa21s9yE5d4qNwdsp5Vc4+i9X/beT3NzWgP8yI0dcVwN1zkOkzyfYD/ADhgLx3mS3w+/vlzzXoAdsTAcvferCAbwHe7jAMqEAygTDqBMOIAy4QDKhAMoEw6gTDiAMuEAyoQDKFv1CMi9M//X6HzetXjxl79tZN1NOf3znzay7sGv3trIutvOEQdQJhxAmXAAZcIBlAkHUCYcQJlwAGXCAZQJB1AmHECZcABlwgGUjYajtfaotfZwis0Au2GZI46rte8C2Ckr+Vi9afWwX0yrB8pcHAXKljlVeZDket0bAXbHaDh674+n2AiwO5yqAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZcIBlJlW/wP99z+b+RjPF3//50bW3ZRf//EPG1nXtPrXc8QBlAkHUCYcQJlwAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZQs/5NZaO0nyfpLj3vvvptkSsO1GjziGgUzPh4gALD7i6L1/dus/r+56nGn1sF+WvcYx773P7/pN0+phv4yGo7X2sPd+OcVmgN2wMBzDKcj7rbWL1trDifYEbLmxaxyXSRxtAN/iPg6gTDiAMuEAyoQDKBMOoEw4gDLhAMqEAygTDqBMOIAy0+p/oJ/94ucbWfeXv/3NRtbdlE19n3k9RxxAmXAAZcIBlAkHUCYcQJlwAGXCAZQJB1AmHECZcABlwgGUCQdQNjaQ6bi1dtZaezrVhoDtN3bEMRuGMt05cBrYP2OT3OattZMkB4seZ1o97JfRaxy998+SZAjIXY8xrR72yLIXR58mOVznRoDdsfBU5faE+t77x+vfDrALxq5xiAXwPe7jAMqEAygTDqBMOIAy4QDKhAMoEw6gTDiAMuEAyoQDKGu999U9WWtHSV6+fPkyR0dHK3teYL2ur69z//79JLnfe78ee7wjDqBMOIAy4QDKhAMoEw6gTDiAMuEAyoQDKBMOoEw4gDLhAMqEAyhbKhym1QO3jYajtWYeLPAtYyMgD5LMhq9FjzOtHvbI2BHH6atp9YuYVg/7ZSwcJ6218ySntwdQA/ttbOj0R0nSWrswgBp4Zam/Vem9f7DujQC7w30cQJlwAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZcIBlAkHUCYcQJlwAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZcIBlC0zdPpi+DqeYkPA9hsbOv1uknmSv/be59NsCdh2Y0ccL5J8meTTYXL9a7XWzlprL5I8W+XmgO00Njt2nuSytZbcTKJ/7fzY3vvl8LijJC9XvUlguywMxy1XSWbr3AiwO8aucZwNv5z13p2GAEnGT1Uup9oIsDvcxwGUCQdQJhxAmXAAZcIBlAkHUCYc
QJlwAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZcIBlAkHUCYcQJlwAGXCAZQJB1AmHEDZUuForR0vmh0L7JfRcLTWzpMcmFYPvLLMCMjRI43hcWdJ7q1oX8AWGzvieC/JRZLTW3Nkv6f3ftl7P83NRHvgR24sHLPcTKp/kuTN9W8H2AVj4ThP8mGS09wceQCMTqu/SvJ4or0AO8J9HECZcABlwgGUCQdQJhxAmXAAZcIBlAkHUCYcQJlwAGXCAZQJB1AmHECZcABlwgGUCQdQJhxAmXAAZcIBlAkHUCYcQJlwAGVjIyAf5WYQ02GST3rvl5PsCthqC8OR5LL3Pm+tPUzy2RQbArbfwlOVWxPqHwzDmQBGjzheuV70m6bVw34ZvTjaWjvJyGmKafWwX5b5W5V3e+/P1r4TYGeMhqP3/tEUGwF2h/s4gDLhAMqEAygTDqBMOIAy4QDKhAMoEw6gTDiAMuEAypb9dOyyWpLMZrMVPy2wTrd+Ztsyj2+995Ut3lp7K8m/V/aEwNTe7r1/MfagVYfjXpI3ksyT/D9P/Cyb+Wi+da277+u2JAdJvuq9fz324JWeqgwLjtbqLq21r3vvC//RoHWwrnWtmyR5uewDXRwFyrYtHJv6V9Sta13rFqz0GgewH7btiAPYAcIBlAkHUCYcQNlWhKO1dtBae9haO2utHUy89qNhxOWUa5601s5ba08nXvd4+B5Puu6t9Sdft7V2MXwdb2Dt4ylfz8Nr+aK19nQYkrY2WxGOJB/m5m63Z7mZCDeljYy27L0/TvJ8GHg1ldkwOHzyP3NrbfI7KIc150keTz3CtLV2nuTg1hjVKVz23j9I8o/c/CytzbaE46T3Ph/+5z7Y9GbWrfd+ezLeZC/oYYD4SW5uLZ7M8K47G76m9CLJl0k+nfid/ywTf4+TaWc9b0s49tV84nekb6I18ZHO6XdiOYnhzegyyXmm/czIe0kukpyu+5ThDmu/zX1bwnE1XOc4TvJ805uZQmvt4fCi3oSnSQ4nXO9kOHQ/nfp60uAq056ezYb1niR5c8J1l5r1vJJ1tuHO0SEYr94Rnkz5Ljy8oK+nHHU5vAu9l5sX2Ce9948nWvebH9qp1vzO+hfDOfhU6716t59N+ecdXs8fJPkkydWU11daa4+meC1vRTiA3bItpyrADhEOoEw4gDLhAMqEAygTDqBMOIAy4QDK/gcIeQHgR5SKvQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 460.8x316.8 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "f, ax = plt.subplots()\n",
    "ax.imshow(batch[0].detach().numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c63ce70-d579-4243-bf49-e5ab6ce166ce",
   "metadata": {},
   "source": [
    "We use the ```generator_jacobian``` function from ```common.attributions``` to compute the generator's attributions. This function expects a generator object with a ```decode``` method, so we construct a wrapper class around the ground-truth generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "id": "c3af1d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GTGenWrapper():\n",
    "    \"\"\"Adapter that exposes the ground-truth generator through a ``decode`` method.\"\"\"\n",
    "\n",
    "    def __init__(self, gt_process):\n",
    "        # gt_process: object providing sample_observations_from_factors(...)\n",
    "        self.my_gt = gt_process\n",
    "\n",
    "    def decode(self, latent):\n",
    "        \"\"\"Return the observations generated for the factor values in ``latent``.\"\"\"\n",
    "        # Decode the argument that was passed in, not the notebook-global `facts`,\n",
    "        # so the wrapper gives the right result for any input.\n",
    "        return self.my_gt.sample_observations_from_factors(latent, ret_torch=True)\n",
    "    \n",
    "from common.attributions import generator_jacobian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 153,
   "id": "4c4da63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = syn_dataset.sample_observations_from_factors(facts, ret_torch=True)\n",
    "jac = generator_jacobian(GTGenWrapper(syn_dataset), facts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 154,
   "id": "25b34c35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.1630)\n",
      "tensor(1.3911)\n",
      "tensor(3.1750)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD4CAYAAADB0SsLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADXVJREFUeJzt3cFuHNeVxvHviIzFGWMAomVkFQSwJhK0DDj2G9DrbJjxCwTOcnYivBggwCwMeTe7MZEXGEOb7AKYeYLIWhuxTAHBrAZyo4GMEFJi88yiqu0S1V3nVrG6u+rW/wcURHXfZl+o9PW9VXW7jrm7AOTp1rY7AGB9CDiQMQIOZIyAAxkj4EDGCDiQMQIOZGx32x0AxsTM9iUdSppI+tLdZ+XjB5I+lnTX3X+9ql1TjODAZn0q6bTcPqk+4e7Hkv5chn1luyYYwYEaZvaupL3E5ufu/jJoc1COxjMz+3DxoLs/rbQ5W9WuKQIOrGBm774r/V+U2Iq/mdkzSVfl30/c/aTh287cfWZmDV+2HAEHVtt7KenfJP1D0PDvkv5T+idJH7n79zVNz8rj64mkP1efMLOjygfCynZNJAfczHYkvS9pJolvqLRjkvYlPXf3eeMXsw+60HgfvCvpH4M2DU5mPZL0r+XPJ2WI70r6QNJHZvaRpK+ut0v/9W+y1G+TmdkvJH3b9o3whnvu/qzpi9gHnQr3gZndkfTi31WEvM5LSf9R/PheMIJvVJMp+kyS/vLNN5pMJmvqTt6m06nuP3gglf+WLbAPbqjNPrileITu6+WoJgF3SZpMJrpz586aujMabafX7IPuJO+DsQQcGCUCDmRsV3FQ+hqkvvYL6A0rt6hNHxFwIHBL0k5Cmz4i4ECAY3AgYxyDt/TX/0n/3Pv977t979/8Jq3dz392FTcaMNv5VXJbn/9hK+/d9fs2xQgOZIyAAxkzxQHmLDowUFwmAzL2k3KL2vQRAQcCHIMDGSPgQMYIOJCxHcVBiZaybgsBBwKcRW9p1uC+Jn/8Y7fvfXSU1u7nP+v2fTE8O4pH6L6O4H09dAB641bilsLM9s3syMw+KW+4WH3uoZkdVf7+RbndvUnfAdRYHIPXbQ1G8LqKJWeLH8zsUMV9447d/UwtEXAg0OUIrrJiSRnauoolTyR9J+nr6yN9E5xkAwINL5OdmtnifuttKptIksqyRSdlhZNDSY/b/B4CDgQaBvywbWWTVe0lTRPaLUXAgUDHN3xYWtmkLD74oaTvJcnMFsfnU3c/bdLflv0CxqnLlWzlsff1afvT8rnjSrvW5YqqCDgQ4PvgQMZYiw5kjJsutrS3l972l7/c3ntj3BjBgYzxZRMgY0P+sgkBBwKcRQcyxkk2IGOcZAMyRsCBjBFwIGPcdBHIGCN4S/d/kV6a9+S/1tiREdtmad5tlwVORcCBjBFwIGNcBwcyxko2IGNM0YGMEXAgY10eg5c3WTxUcVfVL8vbIy+eeyjpzN0f17Vroq8fPEBvdFz4IKmySdAuGSM4EOi48MFBORrPzKyusklqu1oEHAiYpLLCyOo27osfo8IHG0XAgcjenhQEXO7S+XnKb0utbNK0AspSBByI7O6mBTxNUmWT6+2adLeKgAORnR3pVnAUfpX2vYoGlU2WtWuMgAOR3d3OAr5pBByI3L5djOJ15vP657eEgAOR3d044NEx+pYQcCBCwIGM7ewUIR+gYfYa2KTdXQIOZGtvLw745eVm+tIQAQciKVP09IUuG0XAgUjKFJ2AAwNFwIGM3b4tvfNOfZtopduWEHAgkjKCs1QVGKiUk2wsVQUGKmUEJ+DAQO3tFcfhdViqCgxUyhSdhS7AQKVM0Xu6lLWfvQL6hIADGbt9uzgOr8NCF2CgOhzBV1UsWfa4mX1RvuxReY+2xvq5/Abok8VJtrotuiHEj1ZVLHnjcTM7lDSTdNw23BIBB2JRuN8c4U/N
7Em5LSs5dODuszK0H9Y8/kTSd5K+Lkf3dl1v+0JgNPb24mPwH5eqdlLZpJy6n5QVVQ4lPW7zewg4EEm5Dp4+RV9VsWTl45KmyX29hoADkW4vky2tbLLk8cX0furup436W+1W2xcCo9FhwOsqm1x7/MZVTSQCDsRSroNnsFTVJGk6bX04MHqVf7u230xgH9xQq30wkpVs+5J0/8GDNXVlVPYlvWj5OvZBN9L3wUgC/lzSPRUX3/u5Lq//TMV/rOctX88+uLnm+6Dbs+gblRxwd59LerbGvoxFm5FbEvugQ832wUhGcGCcUk6yvXq1mb40RMCBCCM4kDECvl1mdlfSkaSZu3eyQADNZbsfxnCSreemKlb+TCTJzB65+/F2uzRKP+wHMzuQ9LGk/3b3p/Uv67mUL5tcXGymLw1lEfDyy/EHg/+PNHDX94OZnalYXz3s/TLgKXoW3wc3s0eSZjf53ixubsl++CCLqXq3N3zYqH5+7DTk7sdm9lDFV+seq1zxhc2q7gczm0jaN7N9d2/1XebeGPAI3s9eteDun1d+/u02+zJm1f2QjZTr4FFhhC3JJuDA2jCCAxkj4EDGuA4OZCzlOnj0/JYQcCDCFB3IWIdT9NTKJmXzt9o1lcVCF2CtmhU+iCRVNqlp16zrbV8IjMXV7ju62n0nbFM6NbN5+fPJkpV8B+VoPDOztyqbVB7fX9GuEQIOBC4v45umVp7vpLJJVwg4EGgY8EhqZZN/XtGuEQIOBObzOMDzef3zFUmVTVQEu/r3Vsx7Wrgc2DYzuyPpxbffvtBkcqe27XT6ve7de0+S3mOKDgzI+XmxRW36iIADgY6n6BtFwIFAxyfZNoqAAwECDmTs4iI+xu7pPRcJOBBhBAcyxkk2IGOM4EDGzs/jeypyHRwYKKboQMaYogMZI+BAxkZxHdzMdiS9L2kmia+gtWMqyio9d/eeHrXhurGM4O9L+nZdHRmZe5KebbsTSDOWk2wzSfrLN99oMpmsqTt5m06nuv/ggVT+W2IYxjKCuyRNJhPduVP/5XeEOMQZkPPz+KapXAcHBmosU3RglMYyRQdGaRMBT6144u4zM/uifNkjdz+r+70EHAhcXMSViTq4Dv6ppM9UBPkTSZ8ve9zMnqo4SftZSjkjShcBgcUIHm2lUzN7Um5NSg4duPusHJHfqnhSefyJpO8kfV2O7rUYwYFAnyqblKP2iZlJxdT9cV37rQb8dz0tml71u76eHsXGdHkWvXJMXXWq9IonPzwuaRq9HyM4EOjyJFs5Ar816ppZUsWTyrR/6u6n0fsRcCBwcSEVM+L6NjdRHmNfL1H0tPyz+nijMkYEHAhwHRzIGAEHMsZSVSBj5+dSVIR38Dd8AMbq8jJeycYUHRgopuhAxi4vpVvBom5G8CVYJYYhuLiQrq7q27x+vZm+NMUIDgQuL+OFLozgwEARcCBj83kc8L4ebRJwIJBy00VGcGCgUsJLwIGBSpl+M0UHBuryMl6qSsCBgbq4iBe6RNfJt4WAA4GUlWwEHBgoAt6S7fxqm2+fxOd/2HYXsGXzeRzg6Bh9WxjBgUDKSrabBnxVZZPyuYeSztz9cV27ZSh8AATOz9O2G/pUxe2TT1VUNqk6S2z3FkZwIHQl9+gg+4fnT81scdHsxN1T74J6UI7GMzP7sIN2kgg4kGBeblEbSWuubNIUAQdCjQJeq0Vlk+tS20ki4ECC1+UWtYk1rWzi7k9VFB1czAreaBe9HwEHQleKR+ibXQivq2zi7sdBu5UIOBDqboq+aQQcCBHwVlglhmG4VHyM3c8vhDOCAyFGcCBjBBzIGAEHMsYxOJAxRnAgYwQcyBgBBzLGMTiQMUZwIGMEHMgYAQcyxjE4kDFGcCBjBBzIGAEHMnal+JZM/axdRMCB0GtJrxLatJda2aT8+xflU4/Ke7StRGUTIDRP3CQVhQ+elFtYeaQiqbKJmR1Kmkk6jsItMYIDCTZS+CC1YskTSXclfW1m/xLV
JiPgQMgVH2NvprxoGegTK6ohHmrJPdarCDgQeqX4GDx6vtBBZZOFM0nTqBEBB0LdXSa7aWWTynH91N1Po/cj4EBo/dfBG1Q2Sa5qIhFwIEF/jsGbIuBA6JWknyS06R8CDoRYqgpkjIADGVt/+eB1IeBA6LXiteY3W4u+LgQcCDFFBzJGwIGMcR0cyNgrSTsJbfqHgAMhpuhAxrhMBmSMe7IBGXul+O5mHIMDA8UxOJCxcRyDmyRNp+FdYrBC5d/OttkPNDWOEXxfku4/eLCmrozKvqQX2+4EUo3jJNtzSfdU3JO5n8t2+s9UhPv5tjuCJl4pnnTd7CTbqsIHZnYg6WMV92f7dV2BhGWSA+7uc0nPWvYfP2LkHpxzxSP0jc+ifyrpMxXB/UTS54sn3P3YzB5Wwr603TKcZANWO5f0v9KffprY/m+SvjKzxafBSYObJC4tfFDeVXXhbFW7VQg4sIK7vzSzu5L2El9y7u4v19SdmbvPyoIHyQg4UKMMbGehbVP4wMyOKjOBRgUSzH3458vKT9kjFZ9yje4bDfRB+X94Efwvyz/vSvpA0kcqqph8peJe6T+0i06y5RLw/fLHSfnnkaTTa8cvwOhkUT64/BS76+5nlZKqtZ9swBhkEfCyrtOsMpI/1tvHOcDoZDFFlyQze6jiMsKs/FOeUCAdyFk2AQfwtiym6ACWI+BAxgg4kDECDmSMgAMZI+BAxgg4kDECDmSMgAMZI+BAxgg4kDECDmSMgAMZI+BAxgg4kDECDmSMgAMZI+BAxgg4kDECDmSMgAMZI+BAxv4fW3mnA16pd3cAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 288x288 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "f, ax_grid = plt.subplots(2, 2, figsize=(4, 4))\n",
    "ax_list = list(ax_grid.flatten())\n",
    "\n",
    "# Hide the panels that have no factor to show (the 4th one for ColorBar).\n",
    "for unused_ax in ax_list[num_factors:]:\n",
    "    unused_ax.axis(\"off\")\n",
    "ax_list = ax_list[:num_factors]\n",
    "\n",
    "for idx, a in enumerate(ax_list):\n",
    "    # Sum the attribution of factor idx over the last axis to get one 2-D map.\n",
    "    out = torch.sum(jac[0, idx,:,:,:], dim=-1)\n",
    "    print(out.norm()) # Print the norm of the attribution \n",
    "    im = a.matshow(out, vmin=-0.2, vmax=0.2, cmap=\"seismic\")\n",
    "    a.set_xticks([])\n",
    "    a.set_yticks([])\n",
    "    a.set_xlabel(f\"$z_{idx+1}$\")\n",
    "# One colorbar shared by all visible panels (they use the same vmin/vmax).\n",
    "cbar = plt.colorbar(im, ax=ax_list, shrink=0.95)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a143ad2-12fd-46f9-a1ec-672555ef6b69",
   "metadata": {},
   "source": [
    "## Test for orthogonal / disjoint attributions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81b72993-ca7e-4aef-bfd2-6bd98d442503",
   "metadata": {},
   "source": [
    "Let's check if the attributions of the dataset are orthogonal (IMA) or disjoint (DMA)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 155,
   "id": "aab1810d-eb24-4b71-8850-f5fc3d4a4640",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 8, 8, 3])"
      ]
     },
     "execution_count": 155,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jac[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 156,
   "id": "369fe09c-c989-474f-9af1-4c0d28152d3c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deviation:  tensor(5.6578e-08)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f6df711f780>"
      ]
     },
     "execution_count": 156,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARQAAAEZCAYAAABW7tqnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABrNJREFUeJzt2zFuXNcVx+FzQjpOOkEB0juwq3SEVKVJIddpGHsHXIIFL0FLCJeQZANGVLhLIUBIb4PuUoXCVDZgQzwpNAhcmEMi/pOPb+b7AAIk5uHx4Ar86c7Fm56ZAkj4xdIDAPtDUIAYQQFiBAWIERQgRlCAGEEBYg4qKN39qLtPu/usux8tPc+adfdn3X269Bxr1t0n3f2iu/+29CwpBxWUqvq8ql5uv84WnmXtLpYeYB/MzPOqetXdJ0vPknC89AD37GRmNlW16e6nSw/DYZuZ1z/6cS8CfWg7FHiINtv/6Fbv0HYoF9uzk8dV9WrpYaC7T2fmfOk5Ug4tKC+q6pPt93vzj7iQp1V1ufQQa9bdZ1X1cXd/XFX/mJm/Lz3Tz9U+bQykOEMBYgQFiBEUIEZQgBhBAWIEBYgRFCDmIIOyfaCIn8k6ZuzTOh5kUMonjVOsY8berGP00fvuPqqqD6pqU1UP+RHco+7+zdJD7AHrmPHQ17Gr6lFVfTMzb3demHz0vrs/rKqvYjcEHpKPZubrXRekPxy4qar6w6//VO/1++FbH5arb79begSoqqof6vv6Z31Rtf373iUdlKmqeq/fr1/2r8K3PixXfbX0CPDO/MR31zjUQ1ngDggKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMce7XuzuR1X1rKoeV9VfZ2ZzL1MBq3TTDuXzqnq5/Tq7+3GANdu5Q6mqk+2uZNPdT6+7qLvP6l1wjpLDAesSOUOZmfOZeVLv3h4BB+qmHcrF9hzlcVW9uod5gBW7KSgvquqT7ffndzwLsHI7gzIzFyUkwC15DgWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUICY47u46dW339VVX93FrQ/G2z+eLD3CXjj68vXSIxwUOxQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYm4MSnd/1t2n9zEMsG632aFc3PkUwF44Ttyku8+q6qyqjhL3A9YpcoYyM+cz86SqniXuB6yTQ1kg5jZveZ5W1eVdDwKs341BmZnn9zEIsH7e8gAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChBzvPQA/LSjL18vPcJe+OLf/1p6hNW7fPO2fvv7211rhwLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQc7zrxe4+qapPq+p3M/Pn+xkJ
WKsbdygz87yqXm3jAnCtnTuUmXn9ox8vrruuu8+q6qyqjkJzASt02zOUzcxsrntxZs5n5klVPcuMBazRjUHp7tOZOb+PYYB12xmU7VuZT7v7L919ek8zASt10xnKeVXZnQC34jkUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYo7D9+uqqh/q+6oJ3xn+D5dv3i49wuq92fxvDfuma3sm95ff3R9W1VexGwIPyUcz8/WuC9JBOaqqD6pqUw97j/Kyqp4tPcQesI4ZD30du6oeVdU3M7Nzyxd9y7P9ZTsL9hB099uZuVx6jrWzjhkrWcf/3OYih7JAzKEG5XzpAfaEdczYm3WMnqEAh+1QdyjAHRAUIEZQgBhBAWIEBYj5LznA4zuUEpfAAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 316.8x316.8 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Orthogonal gradients (IMA)\n",
    "jacobian = jac[0].reshape(num_factors, -1)\n",
    "prod = jacobian @ jacobian.t()\n",
    "print(\"Deviation: \", torch.sum(torch.abs(prod-torch.diag(torch.diag(prod)))))\n",
    "plt.matshow(prod)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6ec2891-9ce7-4846-8798-8849741e78b1",
   "metadata": {},
   "source": [
    "Looks good for ```Colorbar``` and ```Fourbars```: Both have orthogonal jacobians."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 157,
   "id": "41418c6f-a40b-4146-b5d8-d1c00fa3f5ac",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deviation:  tensor(1.0501)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.image.AxesImage at 0x7f6ddda3fe48>"
      ]
     },
     "execution_count": 157,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARQAAAEZCAYAAABW7tqnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABspJREFUeJzt27GKXOcZx+H31axjl2sF0jvYVaoIqQmkk7uUG/sO9hIsfAm6hOwlOKnTWGU6gSGdwWbduIzElMZCel14CG40u6D/7tkz8zywsMscvn35Fv30ncNMz0wBJNxbegDgcAgKECMoQIygADGCAsQIChAjKEDMUQWlu0+7+6y7z7v7dOl51qy7v+jus6XnWLPuftDdT7v7n0vPknJUQamqL6vq2e7rfOFZ1u5y6QEOwcw8qarn3f1g6VkSTpYe4JY9mJltVW27+9HSw3DcZuab3/x4EIE+thMK3EXb3X90q3dsJ5TL3bOT+1X1fOlhoLvPZuZi6TlSji0oT6vqs933B/NHXMijqnqx9BBr1t3nVfVpd39aVV/PzL+WnuldtU8bAymeoQAxggLECAoQIyhAjKAAMYICxAgKEHOUQdm9oYh3ZB8zDmkfjzIo5ZPGKfYx42D2MfrW++7eVNVHVbWtqrv8FtxNd/9+6SEOgH3MuOv72FV1WlU/zMzrvRcm33rf3R9X1XexBYG75JOZ+X7fBekPB26rqv5y8rd6r98PL31c5tXPS49wEDb3P1x6hNV79ean+s/2q6rdv+990kGZqqr3+v36naC8k+leeoSDsLn3wdIjHJIrb2eO9aEscAMEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgJiTfS9292lVPa6q+1X11cxsb2UqYJWuOqF8WVXPdl/nNz8OsGZ7TyhV9WB3Ktl296O3XdTd5/VrcDbJ4YB1iTxDmZmLmXlYv94eAUfqqhPK5e45yv2qen4L8wArdlVQnlbVZ7vvL254FmDl9gZlZi5LSIBr8j4UIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYk5uYtF59XNN900sfTTe/PXPS49wEDbf/rj0CAdgc+0rnVCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiLkyKN39RXef3cYwwLpd54RyeeNTAAfhJLFId59X1XlVbRLrAesUeYYyMxcz87CqHifWA9bJQ1kg5jq3PI+q6sVNDwKs35VBmZkntzEIsH5ueYAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYg5uYlFN/c/rM29D25i6aOx+fbHpUc4CP/+79dLj7B6L16+rj/86XrXOqEAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYICxAgKECMoQIygADGCAsQIChAjKECMoAAxggLECAoQIyhAjKAAMYIC
xJzse7G7H1TV51X1x5n5++2MBKzVlSeUmXlSVc93cQF4q70nlJn55jc/Xr7tuu4+r6rzqtqE5gJW6LrPULYzs33bizNzMTMPq+pxZixgja4MSnefzczFbQwDrNveoOxuZT7v7n9099ktzQSs1FXPUC6qyukEuBbvQwFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCBGUIAYQQFiBAWIERQgRlCAGEEBYgQFiBEUIEZQgBhBAWIEBYgRFCDmJLxeV1W9evNTeNljtFl6gIPw4uXrpUdYvZfb/+9hX3Vtz0zsF3f3x1X1XWxB4C75ZGa+33dBOiibqvqoqrZVlVs471lVPV56iANgHzPu+j52VZ1W1Q8zs/fIF73l2f2yvQW7C7r79cy8WHqOtbOPGSvZx/9d5yIPZYGYYw3KxdIDHAj7mHEw+xh9hgIct2M9oQA3QFCAGEEBYgQFiBEUIOYXKuviu8NEEW4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 316.8x316.8 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Disjoint gradients (DMA) \n",
    "jacobian = jac[0].reshape(num_factors, -1)\n",
    "prod = torch.abs(jacobian) @ torch.abs(jacobian).t()\n",
    "print(\"Deviation: \", torch.sum(torch.abs(prod-torch.diag(torch.diag(prod)))))\n",
    "plt.matshow(prod)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09bb39ff-3a9a-43a0-a038-d39da009af32",
   "metadata": {},
   "source": [
    "Not diagonal for ```Colorbar```, which does not have disjoint jacobians. ```FourBars``` features disjoint jacobians."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 158,
   "id": "476cf985",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([21, 3])\n",
      "torch.Size([21, 8, 8, 3]) <class 'torch.Tensor'>\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjgAAAEOCAYAAACEvm3bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADX9JREFUeJzt3WGs3WV9B/Df03agnY13bVHYxIarDZBBVqD6YpvTJRAIE6ZLM4yJLyShizWZ2TJhXTIVXNaBb+YS53aXuGwazZYuGliIhvoCTPbGAmvI4kyxjKyBbdxLrgoJLOU+e9GLZUT+z/9wzv+ce37383nV9Pnl//zO0/+f++U55zy31FoDACCTLbNuAABg0gQcACAdAQcASEfAAQDSEXAAgHQEHAAgHQEHAEhHwAEA0tk2xEVLKVsj4pKIWI0IJwkCAJNQImIhIp6otb7UVThIwImz4ebkQNcGADa3vRHxeFfBUG9RrQ50XQCAZs4YKuB4WwoAGEozZ/iQMQCQjoADAKQj4AAA6Qg4AEA6Ag4AkM5Q5+A01eqLVi8rpTRrrNc5rfWyVudYq/48h6OZxr115PBDneM/Xn1x7Dl2LJzfOX74yK+NPUefe+sLH/tC53iWe6/PWnz8ix+fyFx2cACAdAQcACAdAQcASEfAAQDSEXAAgHQEHAAgHQEHAEhHwAEA0pnZQX8A0OXo3/1b5/h/P/3c2HO89aI3dY5P4qC/Pu7/4v2d4zWSHPQX7YP+JsUODgCQjoADAKQj4AAA6Qg4AEA6Ag4AkI6AAwCkI+AAAOk4BweADenSK3Z3jl/w1p8de46dF7xx7GtMwuJVi90FOY7BiV7H4Dw6mans4AAA6Qg4AEA6Ag4AkI6AAwCkI+AAAOkIOABAOgIOAJCOgAMApFNqnfzpQaWUXRGx3FUzxLzzqpT2yUfW65zWerXW6siR9lo++eRILQ1iz552zeHD3Wsx7lptJp7D0bi3+nNvjabPekXE7lrrSleBHRwAIB0BBwBIR8ABANIRcACAdAQcACAdAQcASEfAAQDSaZ6DU0pZjIgDEbFaa13qdVHn4IzEGQmjGff8jf3722v58MMjtTSIa65p1xw/7hycSZnWc/jY736yc/zJL3157Dn23PqRzvEr/+JzY8/h3urPf+NHM81zcJ6NiKWIOLY+8d19ZgYAmJVmwKm1rkbEYq31VKu2lHKwlHI81sMQAMAsNAPO+o7NailloVVba12qte6PiGsn0RwAwOvRZwfnjjj7GZyXQ0sz6AAAzNK2PkW11nte8effGa4dAIDx+Zo4AJCOgAMApNPrLaohPPPMmWbNDdf95xQ6Gd43H7h48DnW1ta6CzbKEQuN4w22bBk+c192WbumtZzT0KfPcd3ziX9s1jzy4MnhG5mCq9+7d9YtRETESy++2D3+/PODzwGbgR0cACAdAQcASEfAAQDSEXAAgHQEHAAgHQEHAEhHwAEA0hFwAIB0ZnbQ35n2OX9x4sQLwzcyBX1e67h+/92/1zl+8uHHh2+ih73XvLNz/M+Pf37wHr7ylcZpg5vIU0+sNGu+f+L0FDoZ3oVv3znrFiIi4pJDt3WOX/gbN4w9xxv3DH+4KGx0dnAAgHQEHAAgHQEHAEhHwAEA0hFwAIB0BBwAIB0BBwBIZ2bn4Gzf3j6L5GOHfm4KnQyvz2uFWfjVG69o1lx4cY7n8J1X/kLn+Ofvm04fb/6lK8caB/qxgwMApCPgAADpCDgAQDoCDgCQjoADAKQj4AAA6Qg4AEA6Ag4AkE6ptU7+oqXsiojliV8YACBid611pavADg4AkI6AAwCkI+AAAOkIOABAOgIOAJCOgAMApCPgAADpbJvVxEOcvzOvSinNGut1Tmu9rNU501irH9z3QLPm6zd/dKw5Pnjv3zZr3nHTdWPN4TkcjeewP/fWaPqsVx92cACAdAQcACAdAQcASEfAAQDSEXAA
gHQEHAAgHQEHAEhnZufgAGSzcuTP2jV/2q4Z2q4/+sPu8cPd4zAP7OAAAOkIOABAOgIOAJCOgAMApCPgAADpCDgAQDoCDgCQjoADAKTjoD9gLO+46bpmzR/U01PoZPbqi//brFl77rkpdNKtT5/TUGttFUynkS6l9Chp17TMxVpMwgTWqi87OABAOgIOAJCOgAMApCPgAADpCDgAQDoCDgCQjoADAKTjHByACdnxoVuaNedftW8KnXQ779JLZ91CRESs3PnZzvHlO++aUievbfenP9Wu+Uy7puUbv3lr5/gP7ntg7Dk2gj7nZk2KHRwAIB0BBwBIR8ABANIRcACAdAQcACAdAQcASEfAAQDScQ4OvA7f/PrJsa9xwwf3TqATNpLzL2ufL9OnBhifHRwAIB0BBwBIR8ABANJpfganlLIYEQciYrXWujR8SwAA4+nzIeNnI2IpInaWUq6OiFsi4h9qrY+8urCUcjAiDkbE1ol2CQAwguZbVLXW1YhYrLWeWg81RyJi/2vULtVa90fEtZNtEwCgv2bAKaXcHRGrpZSF9b/a760qAGAja75FVWu9o5Rye0ScKqXsjIiFUspCrfXo8O0BAIyu1Fonf9FSdkXEclfNM//zQvM6N1/3L5NqaabufeCXO8cveMsbmtdo/TvVtbXWBZpzTEUp3cNb2l/sK41rDHFPv9rbyufGvsbp+skJdNJtI6zVvGitVYT1eiX3Vn/urdH0Wa+I2F1rXekq8DVxACAdAQcASEfAAQDSEXAAgHQEHAAgHQEHAEhHwAEA0unzu6gGceZM+zv/j5344RQ6GV6f1zqulbv+pHN8+c67Bu+hj92f/lT3+Ge6xzeKN+04b9YtANDBDg4AkI6AAwCkI+AAAOkIOABAOgIOAJCOgAMApCPgAADpCDgAQDozO+hv+/atzZqDhxan0Mnw+rxW5su//+gTs24BgA52cACAdAQcACAdAQcASEfAAQDSEXAAgHQEHAAgHQEHAEin1Fonf9FSdkXE8sQvDAAQsbvWutJVYAcHAEhHwAEA0hFwAIB0BBwAIB0BBwBIR8ABANIRcACAdLbNauIhzt+ZV6WUZs3Y6/XVr7ZrvvOd8eaIiHjPe7rHP/zhsadorZd765xprNUzj32vWfOvf/nlsebYd+gjzZoLrrx8rDmm8hwm4jnsz701mj7r1YcdHAAgHQEHAEhHwAEA0hFwAIB0BBwAIB0BBwBIR8ABANIRcACAdGZ20B9T9uCD7ZqlpfHnWVvrHp/AQX9sLD/6j9PNmhN/9fdjzbF44683a8Y96A/IxQ4OAJCOgAMApCPgAADpCDgAQDoCDgCQjoADAKQj4AAA6cz1OTi1debKlJQtc5ATL764XbNv33TmIZXz3ryjWfOWfVcMPgdsZrXWHjXjzVFKn5oeRVMyBz+ZAQBGI+AAAOkIOABAOgIOAJCOgAMApCPgAADpCDgAQDoCDgCQTulzONDIFy1lV0Qsd9VMYt5v/fzezvEXnv6vsed4w0UXNmuuf+rkWHP0ORhpiH+nedVaL2t1jrXqz3M4mkncW08/+Wzn+JeOfGuknoZw6+HrmzUX7dnZOT6Ne+uf7/txs+YDN58ea45v3Pu2Zs37bxr/UM6ehwXurrWudBXYwQEA0hFwAIB0BBwAIB0BBwBIR8ABANIRcACAdAQcACCdbbNuAIDNaXX5uc7xo3/90JQ6eW2/dduvNGta5+AwG3ZwAIB0BBwAIB0BBwBIR8ABANIRcACAdJrfoiqlLEbEgYhYrbUuDd8SAMB4+nxN/NmIWIqIna8IO8dqrY+8urCUcjAiDkbE1ol2CQAwgmbAqbWullKufjnQlFIiIlZfo3YpIpZKKbsiYnmSjf401z91cugpSGhtrTZrartkcGcftW5btvQoYmK+/U+Pdo7/zWfvn1Inw7rtj2+cyjznb/+ZzvHLrnr7VPro0upxo3j/TTuaNWfq5VPoZONofganlHJ3RKyWUhbW/+poRFw7
aFcAAGPos4NzRynl9og4VUpZjYhTEXFs8M4AAF6nXr+qodZ6z9CNAABMiq+JAwDpCDgAQDoCDgCQjoADAKQj4AAA6fT6FhVk8r53P9isefThn3qW5VRddc1Cs+ah4+8bvhF+4ocrz3eOf//E6Sl1MqzW65yUxcsv6hz/2iOHp9IHOdnBAQDSEXAAgHQEHAAgHQEHAEhHwAEA0hFwAIB0BBwAIB3n4AD0tPiL3ee23HLovVPqZFit1wnzwA4OAJCOgAMApCPgAADpCDgAQDoCDgCQjoADAKQj4AAA6Qg4AEA6pdY6+YuWsisilid+YQCAiN211pWuAjs4AEA6Ag4AkI6AAwCkI+AAAOkIOABAOgIOAJDOUAGnDHRdAIBmzhgq4CwMdF0AgGbOGOqgv60RcUlErEbE5CeYvGMRce2smwB+wjMJG89GeC5LnA03T9RaX+oq3DbE7OuTPj7EtYdQSnmpdSIiMD2eSdh4NtBz2es3JfiQMQCQjoBz1tKsGwD+H88kbDxz9VwO8hkcAIBZsoMDAKQj4AAA6Qg4AEA6Ag4AkI6As66Ucnsp5cCs+4DNrJSyUEo5UEo5WEpxIjpsEPP4M1LAOefUrBsA4nCcPS31WEQcnHEvwDlz9zNykJOMAV6nq2utqxGxWkp516ybAeaXHRwAIJ1Nu4Oz/l7iYkQsrf8fIzB7p9Y/e7MzIr4762aA+bVpA06t9eir/updEbERfokYbGZ3R8Rvr/95ro6Fh+Tm7mekX9UAAKTjMzgAQDoCDgCQjoADAKQj4AAA6Qg4AEA6Ag4AkI6AAwCkI+AAAOn8H2h+nSHlvodwAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 576x432 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "## Plot a random sample (used in the appendix)\n",
    "n_intervals = 7\n",
    "n_gen = n_intervals*num_factors\n",
    "facts = torch.tensor(syn_dataset.sample_factors(n_gen), dtype=torch.float)\n",
    "print(facts.shape)\n",
    "batch = syn_dataset.sample_observations_from_factors(facts, ret_torch=True)\n",
    "print(batch.shape, type(batch))\n",
    "#batch = batch.transpose(3,2).transpose(2, 1)\n",
    "plot_traversal(batch, n_row=n_intervals)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "interpretable_concepts",
   "language": "python",
   "name": "interpretable_concepts"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
